from mynumpy import *
# Short module-level aliases for distributions from autograd.scipy.stats
# (normal, Student-t, gamma); each exposes a scipy.stats-style .logpdf.
norm, studt, gamma = (
    autograd.scipy.stats.norm,
    autograd.scipy.stats.t,
    autograd.scipy.stats.gamma,
)

def compute_loglike(logp, ll0=None, tol=1e-5, ll0_tol=1e-6):
    """Return the log of the integral of exp(logp(z)) over the whole real line.

    The integral is computed with ``scipy.integrate.quad`` on (-inf, inf)
    (``limit=2500`` subintervals) and quad's absolute-error estimate is printed.

    Args:
        logp: callable mapping a scalar z to the log of a (possibly
            unnormalized) density.
        ll0: optional known value of the log-integral. When given, the
            numerical result must agree with it to within ``ll0_tol``.
        tol: largest accepted absolute-error estimate reported by quad.
        ll0_tol: tolerance for the comparison against ``ll0``.

    Returns:
        log of the numerically computed integral.

    Raises:
        ValueError: if quad's error estimate is not below ``tol``, or if
            ``ll0`` is given and the result disagrees with it.
    """
    def integrand(z):
        return exp(logp(z))

    like, err = scipy.integrate.quad(integrand, -inf, inf, limit=2500)
    print('error of loglike', err)
    # Explicit checks rather than ``assert`` so they still run under ``python -O``.
    # ``not x < tol`` (instead of ``x >= tol``) also rejects a nan error estimate.
    if not err < tol:
        raise ValueError(
            'quad error estimate {!r} is not below tolerance {!r}'.format(err, tol))
    ll = log(like)
    # If the true log-integral is known, use it as a test of the quadrature.
    if ll0 is not None and not abs(ll - ll0) < ll0_tol:
        raise ValueError(
            'computed loglike {!r} differs from expected {!r} by more than {!r}'
            .format(ll, ll0, ll0_tol))
    return ll

class target:
    """A 1-D target density normalized numerically by quadrature.

    ``compute_loglike`` integrates ``exp(logp)`` over the real line; the
    stored ``self.logp`` subtracts that log-mass, so ``exp(self.logp)``
    integrates to (approximately) one.
    """

    def __init__(self, logp, label, ll=None):
        """Normalize ``logp`` and optionally check the normalizer.

        Args:
            logp: callable returning the (possibly unnormalized) log density at z.
            label: short name stored on the instance as ``self.label``.
            ll: optional exact log normalizing constant of ``logp``; when
                given, the numerically computed value must agree to within 1e-6.

        Raises:
            ValueError: if ``ll`` is given and disagrees with the numerical value.
        """
        self.loglike = compute_loglike(logp)
        # Reads self.loglike at call time, so the normalization always reflects
        # the stored attribute.
        self.logp = lambda z: logp(z) - self.loglike
        self.label = label
        if ll is not None:
            mismatch = abs(self.loglike - ll)
            # Explicit check instead of ``assert`` so it also runs under
            # ``python -O``; ``not ... <`` rejects nan as well.
            if not mismatch < 1e-6:
                raise ValueError(
                    'loglike of target {!r} is {!r}, expected {!r} (|diff| = {!r})'
                    .format(label, self.loglike, ll, mismatch))

    def p(self, z):
        """Return the normalized density exp(self.logp(z))."""
        return exp(self.logp(z))

    def __repr__(self):
        return '{}(label={!r}, loglike={!r})'.format(
            type(self).__name__, self.label, self.loglike)

# Each target_* factory below returns a `target` wrapping a 1-D log density log p(z).

def target_norm():
    """Standard normal target; its exact log normalizer is 0."""
    return target(norm.logpdf, 'norm', 0.0)

def logp_mix(l1, l2, alpha, beta=None):
    """Log of a two-component mixture, computed stably in log space.

    Evaluates log(alpha * exp(l1) + beta * exp(l2)) via np.logaddexp,
    i.e. logaddexp(l1 + log(alpha), l2 + log(beta)).

    Args:
        l1, l2: log densities of the two components.
        alpha: weight on the first component.
        beta: weight on the second component; defaults to ``1 - alpha``.
    """
    second_weight = 1 - alpha if beta is None else beta
    return np.logaddexp(l1 + log(alpha), l2 + log(second_weight))

def target_imbalanced():
    """Imbalanced mixture: a gated wide normal plus a small narrow spike.

    Weight .97 goes on a N(3.5, 2^2) log density minus
    logaddexp(0, -10*(z - 3.5)), i.e. multiplied by the logistic gate
    sigmoid(10*(z - 3.5)). Weight .03 goes on N(-7, 0.4^2). The gate removes
    mass, so the density is unnormalized and no exact log normalizer is given.
    """
    center = 3.5
    spike_loc = -2 * center

    def gated(z):
        return norm.logpdf(z, loc=center, scale=2) - np.logaddexp(0, -10 * (z - center))

    def logp(z):
        wide = gated(z)
        spike = norm.logpdf(z, loc=spike_loc, scale=.4)
        return logp_mix(wide, spike, .97)

    return target(logp, 'imbalanced')

def target_half():
    """Mixture with 5% weight on N(0, 1) and 95% on a half-normal over z >= 0."""
    def half_normal(z):
        # log(2 * phi(z)); the huge finite penalty effectively zeroes z < 0
        return norm.logpdf(z) - log(.5) - 1e300 * (z < 0)

    def logp(z):
        return logp_mix(norm.logpdf(z), half_normal(z), .05)

    return target(logp, 'half')

def target_halfhalf():
    """Equal mixture of two shifted copies (centered at -1 and +1) of the
    5% N(0, 1) / 95% half-normal mixture."""
    def half_normal(z):
        # log(2 * phi(z)); the huge finite penalty effectively zeroes z < 0
        return norm.logpdf(z) - log(.5) - 1e300 * (z < 0)

    def normal_half_mix(z):
        return logp_mix(norm.logpdf(z), half_normal(z), .05)

    return target(
        lambda z: logp_mix(normal_half_mix(z + 1), normal_half_mix(z - 1), .5),
        'halfhalf')

def target_mix():
    """75/25 mixture of unit-variance normals at +2 and -2 (log normalizer 0)."""
    def logp(z):
        right = norm.logpdf(z - 2)
        left = norm.logpdf(z + 2)
        return logp_mix(right, left, .75)

    return target(logp, 'mix', 0.0)

def target_tallwide():
    """Equal mixture of a tall N(1, 0.5^2) and a wide N(-1, 2^2) (log normalizer 0)."""
    def logp(z):
        tall = norm.logpdf(z, loc=1, scale=.5)
        wide = norm.logpdf(z, loc=-1, scale=2)
        return logp_mix(tall, wide, .5)

    return target(logp, 'tallwide', 0.0)

def target_saw_sin():
    """N(0, 2^2) log density plus a sin(3z) ripple (unnormalized)."""
    def logp(z):
        return norm.logpdf(z, loc=0, scale=2) + sin(3 * z)

    return target(logp, 'saw_sin')

def target_saw_cos():
    """N(0, 2^2) log density plus a cos(3z) ripple (unnormalized)."""
    def logp(z):
        return norm.logpdf(z, loc=0, scale=2) + cos(3 * z)

    return target(logp, 'saw_cos')

def target_halfsaw():
    """N(0, 2^2) log density with a cos(3z) ripple for z < 0 and a constant
    +1 offset for z >= 0 (unnormalized)."""
    def logp(z):
        base = norm.logpdf(z, loc=0, scale=2)
        # boolean masks select the branch elementwise
        return (z < 0) * (base + cos(3 * z)) + (z >= 0) * (base + 1)

    return target(logp, 'halfsaw')

def target_gaussmod():
    """N(0, 2^2) log density modulated by 1.5*sin(8z + sign(z)*z^2) (unnormalized)."""
    def logp(z):
        phase = 8 * z + np.sign(z) * z ** 2
        return norm.logpdf(z, scale=2) + 1.5 * sin(phase)

    return target(logp, 'gaussmod')

def target_separated():
    """75/25 mixture of well-separated N(-4, 0.75^2) and N(4, 0.75^2) (log normalizer 0)."""
    def logp(z):
        left = norm.logpdf(z, loc=-4, scale=.75)
        right = norm.logpdf(z, loc=4, scale=.75)
        return logp_mix(left, right, .75)

    return target(logp, 'separated', 0.0)

def target_multiscale():
    """Nested Student-t mixture with components of very different scales.

    Weight .6666 goes on an equal mix of t(df=5, loc=-3, scale=.5) and
    t(df=2, loc=3, scale=3); the remaining weight goes on
    t(df=5, loc=0, scale=.5). The log normalizer is 0.
    """
    def outer_pair(z):
        narrow = studt.logpdf(z, df=5, loc=-3, scale=.5)
        broad = studt.logpdf(z, df=2, loc=3, scale=3)
        return logp_mix(narrow, broad, .5)

    def logp(z):
        center = studt.logpdf(z, df=5, loc=0, scale=.5)
        return logp_mix(outer_pair(z), center, .6666)

    return target(logp, 'multiscale', 0.0)

def target_heavymix():
    """Equal mixture of heavy-tailed t(df=2, scale=.5) at -2 and +2 (log normalizer 0)."""
    def logp(z):
        left = studt.logpdf(z, df=2, loc=-2, scale=.5)
        right = studt.logpdf(z, df=2, loc=2, scale=.5)
        return logp_mix(left, right, .5)

    return target(logp, 'heavymix', 0.0)

def target_heavytail():
    """Single heavy-tailed Student-t with df=2, loc=0, scale=1 (log normalizer 0)."""
    def logp(z):
        return studt.logpdf(z, df=2, loc=0, scale=1)

    return target(logp, 'heavytail', 0.0)

def target_oddcouple():
    """Equal mixture of a broad t(df=2, loc=4, scale=2) and a narrow
    N(-4, 0.5^2) (log normalizer 0)."""
    def logp(z):
        heavy = studt.logpdf(z, df=2, loc=4, scale=2)
        narrow = norm.logpdf(z, loc=-4, scale=.5)
        return logp_mix(heavy, narrow, .5)

    return target(logp, 'oddcouple', 0.0)

def target_recursive_madness():
    """Randomly placed Student-t pair, repeatedly reflected and shifted.

    Calls ``seed(1)`` before drawing the two component locations as
    ``-7 + 14*rand()``. NOTE(review): this presumably reseeds a shared
    global RNG, which is a side effect on any later random draws; confirm
    against ``mynumpy``. Reflections and shifts preserve mass, and the
    target is checked against a log normalizer of 0.
    """
    seed(1)
    # Draw order matters: loc1 first, then loc2.
    loc1 = -7+14*rand()
    loc2 = -7+14*rand()

    def base_pair(z):
        wide = studt.logpdf(z, df=2.1, loc=loc1, scale=1.5)
        narrow = studt.logpdf(z, df=5, loc=loc2, scale=.5)
        return logp_mix(wide, narrow, .6)

    def mirrored(z):
        return logp_mix(base_pair(-z - 1), base_pair(z + 1), .4)

    def mirrored_again(z):
        return logp_mix(mirrored(z - 8), mirrored(-z + 8), .4)

    return target(lambda z: mirrored_again(z + 10), 'recursive_madness', 0.0)

# Targets built at import time; each one is normalized by quadrature when
# constructed. Other factories defined in this module that can be swapped in:
#   target_recursive_madness, target_oddcouple, target_gaussmod,
#   target_heavytail, target_heavymix, target_multiscale, target_separated,
#   target_norm, target_mix, target_tallwide, target_saw_sin, target_saw_cos,
#   target_halfsaw, target_half, target_halfhalf
targets = [target_imbalanced()]

if __name__ == '__main__':
    # Plot each target's normalized density on [-7, 7) and save it as
    # target<n>.png, titled with the target's label.
    from matplotlib import pyplot as plt
    x = arange(-7, 7, .01)
    for n, t in enumerate(targets):
        fig = plt.figure()
        plt.plot(x, exp(t.logp(x)))
        plt.title(t.label)
        fig.savefig('target' + str(n) + '.png')
        # Close each figure after saving so figures do not accumulate in memory.
        plt.close(fig)
